{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "bag_of_logistics.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyO1DdnQCrCvfJuwTSq0nzyK",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/EugenHotaj/nn-hallucinations/blob/master/bag_of_logistics.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "T2gkAhA2Yzwp",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 153
        },
        "outputId": "93f1898b-6241-4e0b-e43e-6ac003458c80"
      },
      "source": [
        "!git clone https://www.github.com/EugenHotaj/nn-hallucinations nn_hallucinations"
      ],
      "execution_count": 92,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'nn_hallucinations'...\n",
            "warning: redirecting to https://github.com/EugenHotaj/nn-hallucinations.git/\n",
            "remote: Enumerating objects: 108, done.\u001b[K\n",
            "remote: Counting objects: 100% (108/108), done.\u001b[K\n",
            "remote: Compressing objects: 100% (97/97), done.\u001b[K\n",
            "remote: Total 108 (delta 56), reused 26 (delta 11), pack-reused 0\u001b[K\n",
            "Receiving objects: 100% (108/108), 13.46 MiB | 8.00 MiB/s, done.\n",
            "Resolving deltas: 100% (56/56), done.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fhbHjHAGuIQY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "from nn_hallucinations import colab_utils\n",
        "import numpy as np\n",
        "import torch\n",
        "from torch import distributions\n",
        "from torch import nn\n",
        "from torch import optim\n",
        "from torch.nn import functional as F\n",
        "from torch.utils import data\n",
        "from torchvision import datasets\n",
        "from torchvision import transforms"
      ],
      "execution_count": 93,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PKzRg91VuUTM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "BATCH_SIZE=512\n",
        "\n",
        "transform = transforms.Compose([\n",
        "    transforms.ToTensor(),\n",
        "    lambda x: x.view(-1),\n",
        "    lambda x: distributions.Bernoulli(probs=x).sample()])\n",
        "\n",
        "trainloader = data.DataLoader(\n",
        "    datasets.MNIST(\n",
        "        root='../data', train=True, download=True, transform=transform),\n",
        "    batch_size=BATCH_SIZE, \n",
        "    num_workers=2)\n",
        " \n",
        "testloader = data.DataLoader(\n",
        "    datasets.MNIST(\n",
        "        root='../data', train=False, download=True, transform=transform),\n",
        "    batch_size=BATCH_SIZE, \n",
        "    num_workers=2)"
      ],
      "execution_count": 94,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "770LEzybMrNU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class MeanLayer(nn.Module):\n",
        "  \"\"\"Linear layer which only models the mean of the labels.\n",
        "\n",
        "  This is achieved by only having a bias parameter which does not depend on \n",
        "  any data. N.B. that the dimension of the bias is hardcoded to be 1.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    self._bias = torch.nn.Parameter(data=torch.zeros((1,)))\n",
        "\n",
        "  def forward(self, x):\n",
        "    \"\"\"Computes the forward pass.\n",
        "    \n",
        "    IMPORTANT: The input is not used in the forward pass computation but must\n",
        "    be provided to determine the batch size.\n",
        "    \"\"\"\n",
        "    batch_size = x.shape[0]\n",
        "    return torch.ones((batch_size, 1)) * self._bias\n",
        "\n",
        "class BagOfLogistics(nn.Module):\n",
        "  \"\"\".\"\"\"\n",
        "\n",
        "  def __init__(self, features):\n",
        "    super().__init__()\n",
        "\n",
        "    self._features = features\n",
        "\n",
        "    linears = [MeanLayer()]\n",
        "    for i in range(1, self._features):\n",
        "      linears.append(nn.Linear(in_features=i, out_features=1))\n",
        "    self._linears = nn.ModuleList(linears)\n",
        "\n",
        "  def forward(self, x):\n",
        "    output = [self._linears[0](x[:, :1])]\n",
        "    for i in range(1, self._features):\n",
        "      output.append(self._linears[i](x[:, :i]))\n",
        "    output = torch.stack(output, axis=1).squeeze(dim=-1)\n",
        "    return torch.sigmoid(output)\n",
        "\n",
        "  def sample(self, conditioned_on=None):\n",
        "    with torch.no_grad():\n",
        "      if conditioned_on is None:\n",
        "        device = next(self.parameters()).device\n",
        "        conditioned_on = (torch.ones((1, self._features)) * - 1).to(device)\n",
        "      else:\n",
        "        conditioned_on = conditioned_on.clone()\n",
        "\n",
        "      for dim in range(self._features):\n",
        "        # TODO(eugenhotaj): We're wasting massive amounts of computation here\n",
        "        # by doing a full forward pass through all models for each dimension. \n",
        "        # We can instead do the same thing we did in forward.\n",
        "        out = self.forward(conditioned_on)[:, dim]\n",
        "        out = distributions.Bernoulli(probs=out).sample()\n",
        "        conditioned_on[:, dim] = torch.where(\n",
        "            conditioned_on[:, dim] < 0, out, conditioned_on[:, dim])\n",
        "      return conditioned_on "
      ],
      "execution_count": 114,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3vjA5_LCQ1qK",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 867
        },
        "outputId": "12acc9f0-6365-4e6d-eb49-cd00d8749947"
      },
      "source": [
        "INPUT_SIZE = 784\n",
        "N_EPOCHS = 50\n",
        "\n",
        "model = BagOfLogistics(INPUT_SIZE,).to(colab_utils.get_device())\n",
        "# TODO(eugenhotaj): Since all the underlying logistic models are disjoint, it\n",
        "# would be more optimal to have one optimizer per model. However, this adds \n",
        "# quite a bit of complexity to the training process. Using adaptive optimizers\n",
        "# also mitigates this issue someone since each parameter will get its own \n",
        "# learning rate.\n",
        "optimizer = optim.Adam(model.parameters())\n",
        "bce_loss_fn = nn.BCELoss(reduction='none')\n",
        "loss_fn = lambda x, y, preds: bce_loss_fn(preds, x).sum(dim=1).mean()\n",
        "\n",
        "train_losses, eval_losses = colab_utils.train_andor_evaluate(\n",
        "    model, \n",
        "    loss_fn, \n",
        "    optimizer=optimizer, \n",
        "    n_epochs=N_EPOCHS, \n",
        "    train_loader=trainloader,\n",
        "    eval_loader=testloader,\n",
        "    device=colab_utils.get_device())"
      ],
      "execution_count": 116,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[1|1796]: train_loss=342.28421095377604 eval_loss=282.7946029296875\n",
            "[2|1805]: train_loss=266.21285200195314 eval_loss=249.74182404785157\n",
            "[3|1795]: train_loss=240.80256975097657 eval_loss=229.90232705078125\n",
            "[4|1792]: train_loss=223.7373075846354 eval_loss=215.20296689453124\n",
            "[5|1808]: train_loss=210.46316151529948 eval_loss=203.44412036132812\n",
            "[6|1779]: train_loss=199.59648662109376 eval_loss=193.46713591308594\n",
            "[7|1794]: train_loss=190.3882016845703 eval_loss=185.00503876953124\n",
            "[8|1797]: train_loss=182.2994200439453 eval_loss=177.5134095703125\n",
            "[9|1801]: train_loss=175.2721121500651 eval_loss=170.9085322265625\n",
            "[10|1799]: train_loss=168.98715891927083 eval_loss=165.08293762207032\n",
            "[11|1795]: train_loss=163.3683205810547 eval_loss=159.72227004394531\n",
            "[12|1801]: train_loss=158.29717920735678 eval_loss=154.98139685058592\n",
            "[13|1798]: train_loss=153.7454821126302 eval_loss=150.67543576660157\n",
            "[14|1797]: train_loss=149.59134009602866 eval_loss=146.6785751953125\n",
            "[15|1788]: train_loss=145.83826706542968 eval_loss=143.10402414550782\n",
            "[16|1649]: train_loss=142.38370578613282 eval_loss=139.8206855957031\n",
            "[17|1792]: train_loss=139.2319891357422 eval_loss=136.8583934082031\n",
            "[18|1795]: train_loss=136.35108859049478 eval_loss=134.14040849609376\n",
            "[19|1777]: train_loss=133.73688153483073 eval_loss=131.6621201171875\n",
            "[20|1767]: train_loss=131.262961710612 eval_loss=129.23270563964843\n",
            "[21|1788]: train_loss=129.02052673339844 eval_loss=127.16640227050782\n",
            "[22|1787]: train_loss=127.04803903808593 eval_loss=125.11660073242187\n",
            "[23|1782]: train_loss=125.08450349121094 eval_loss=123.36581760253907\n",
            "[24|1787]: train_loss=123.34092176920574 eval_loss=121.71992117919922\n",
            "[25|1767]: train_loss=121.6744043334961 eval_loss=120.0763359741211\n",
            "[26|1799]: train_loss=120.16249693603515 eval_loss=118.62398585205078\n",
            "[27|1779]: train_loss=118.78080275065105 eval_loss=117.41100322265625\n",
            "[28|1757]: train_loss=117.43865423583985 eval_loss=116.07924448242187\n",
            "[29|1771]: train_loss=116.24580759684245 eval_loss=114.81695620117188\n",
            "[30|1763]: train_loss=115.10699691975911 eval_loss=113.84051746826172\n",
            "[31|1771]: train_loss=114.06843733317058 eval_loss=112.80711424560548\n",
            "[32|1759]: train_loss=113.0862426920573 eval_loss=111.85753126220703\n",
            "[33|1621]: train_loss=112.14771144612631 eval_loss=110.9456587524414\n",
            "[34|1782]: train_loss=111.35062067464193 eval_loss=110.22167940673828\n",
            "[35|1778]: train_loss=110.50956518554688 eval_loss=109.42169688720703\n",
            "[36|1781]: train_loss=109.74685019938151 eval_loss=108.7435497680664\n",
            "[37|1769]: train_loss=109.0879400349935 eval_loss=107.96806571044922\n",
            "[38|1758]: train_loss=108.39528197835287 eval_loss=107.40393327636718\n",
            "[39|1779]: train_loss=107.85278912353516 eval_loss=106.87200473632812\n",
            "[40|1780]: train_loss=107.27351559651693 eval_loss=106.13585755615235\n",
            "[41|1779]: train_loss=106.70187975260417 eval_loss=105.75408581542969\n",
            "[42|1781]: train_loss=106.18066407470702 eval_loss=105.23093256835938\n",
            "[43|1769]: train_loss=105.69638138020834 eval_loss=104.68659718017578\n",
            "[44|1758]: train_loss=105.2286335164388 eval_loss=104.26955616455078\n",
            "[45|1747]: train_loss=104.83399032796224 eval_loss=103.97493248291016\n",
            "[46|1766]: train_loss=104.47638867594401 eval_loss=103.53803442382812\n",
            "[47|1772]: train_loss=104.08714305419922 eval_loss=103.25254678955078\n",
            "[48|1772]: train_loss=103.68601694335938 eval_loss=102.87341743164062\n",
            "[49|1778]: train_loss=103.38597381591796 eval_loss=102.56980594482422\n",
            "[50|1697]: train_loss=103.06110318603515 eval_loss=102.26977041015626\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TA5X3FjiPT0k",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "51f30989-af09-4c00-95a6-13505b7f82d7"
      },
      "source": [
        "model_weights = model.state_dict()\n",
        "model = BagOfLogistics(INPUT_SIZE,).to(colab_utils.get_device())\n",
        "model.load_state_dict(model_weights)"
      ],
      "execution_count": 118,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<All keys matched successfully>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 118
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "nJQdRlN6gd-x",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "a = model.sample()"
      ],
      "execution_count": 120,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A2EpXfyIpEEO",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 248
        },
        "outputId": "70012daf-8598-45c9-d370-0b1c81ae9f3e"
      },
      "source": [
        "colab_utils.imshow(a.reshape(28, 28))"
      ],
      "execution_count": 123,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAADn0lEQVR4nO3dUWrUYBiGUae4ClfhJsQVuEpXUNxEV9FlmN46EJKpyd88Tc65tFCi8PCBL5m5TdP0Beh5OvoBgHnihChxQpQ4IUqcEPV16Yc/nn4d9l+5z68viz//+e37Bz0JjPXn7+/b3J+7nBAlTogSJ0SJE6LECVHihChxQtTiznkkOyZX53JClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6Kyr4zBv674UakuJ0SJE6LECVHihChxQpQ4IUqcEGXn5GFrW+OaLVvkGXfMNS4nRIkTosQJUeKEKHFClDghSpwQZefkztYtc8vvvuKWucTlhChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEeWVsgPKrUSNfCdtqy7Od8XUzlxOixAlR4oQocUKUOCFKnBAlToiycw6wtrkt7Xln3Ov4Py4nRIkTosQJUeKEKHFClDghSpwQZec8wMgt88j3NW20+3I5IUqcECVOiBInRIkTosQJUeKEKDsnD7NjfiyXE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcEOUrAAd4fn05+hGG2Pr38hWC7+NyQpQ4IUqcECVOiBInRIkTosQJUXbOi1nbGkdutGu/2w56z+WEKHFClDghSpwQJU6IEidEiROi7JwD2Ovm+Xd5H5cTosQJUeKEKHFClDghSpwQZUrhztLccdaP/KxyOSFKnBAlTogSJ0SJE6LECVHihCg7J3e2bJleCduXywlR4oQocUKUOCFKnBAlTogSJ0TZOdmNr/jbl8sJUeKEKHFClDghSpwQJU6IEidEXXLnvPIe57NnPw+XE6LECVHihChxQpQ4IUqcEHXJKWX0VGKuYA8uJ0SJE6LECVHihChxQpQ4IUqcEHXJnXPN1p1ybUf9rDvomV+lK3I5IUqcECVOiBInRIkTosQJUeKEKDvnAeyFPMLlhChxQpQ4IUqcECVOiBInRIkTouycM+yQFLicECVOiBInRIkTosQJUeKEKHFClDghSpwQJU6IEidEiROixAlR4oQocUKUOCFKnBAlTogSJ0SJE6LECVHihChxQpQ4IUqcECVOiBInRIkTosQJUeKEqNs0TUc/AzDD5YQocUKUOCFKnBAlTogSJ0S9AcyJUMTTTlDDAAAAAElFTkSuQmCC\n",
            "text/plain": [
              "<Figure size 432x288 with 1 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    }
  ]
}